{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D3-cg2_rYfe6"
   },
   "source": [
    "# Ungraded Lab: Quantization and Pruning\n",
    "\n",
    "In this lab, you will get some hands-on practice with the mobile optimization techniques discussed in the lectures. These techniques reduce model size and latency, which makes them well suited for edge and IoT devices. You will start by training a Keras model, then compare its model size and accuracy after applying each of these techniques:\n",
    "\n",
    "* post-training quantization\n",
    "* quantization aware training\n",
    "* weight pruning\n",
    "\n",
    "Let's begin!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0gRaAOIsba55"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4nVRm10UNHZ9"
   },
   "source": [
    "Let's first import a few common libraries that you'll be using throughout the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9sL5kmRZbZxX"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "import tempfile\n",
    "import zipfile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GS5gXwABm7XP"
   },
   "source": [
    "<a name='utilities'></a>\n",
    "\n",
    "## Utilities and constants\n",
    "\n",
    "Next, let's define a few string constants and utility functions to make the code easier to maintain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nEuiXyPZMKQm"
   },
   "outputs": [],
   "source": [
    "# GLOBAL VARIABLES\n",
    "\n",
    "# String constants for model filenames\n",
    "FILE_WEIGHTS = 'baseline_weights.h5'\n",
    "FILE_NON_QUANTIZED_H5 = 'non_quantized.h5'\n",
    "FILE_NON_QUANTIZED_TFLITE = 'non_quantized.tflite'\n",
    "FILE_PT_QUANTIZED = 'post_training_quantized.tflite'\n",
    "FILE_QAT_QUANTIZED = 'quant_aware_quantized.tflite'\n",
    "FILE_PRUNED_MODEL_H5 = 'pruned_model.h5'\n",
    "FILE_PRUNED_QUANTIZED_TFLITE = 'pruned_quantized.tflite'\n",
    "FILE_PRUNED_NON_QUANTIZED_TFLITE = 'pruned_non_quantized.tflite'\n",
    "\n",
    "# Dictionaries to hold measurements\n",
    "MODEL_SIZE = {}\n",
    "ACCURACY = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pqdSGWccdk8G"
   },
   "outputs": [],
   "source": [
    "# UTILITY FUNCTIONS\n",
    "\n",
    "def print_metric(metric_dict, metric_name):\n",
    "  '''Prints the keys and values stored in a dictionary'''\n",
    "  for metric, value in metric_dict.items():\n",
    "    print(f'{metric_name} for {metric}: {value}')\n",
    "\n",
    "\n",
    "def model_builder():\n",
    "  '''Returns a shallow CNN for training on the MNIST dataset'''\n",
    "\n",
    "  keras = tf.keras\n",
    "\n",
    "  # Define the model architecture.\n",
    "  model = keras.Sequential([\n",
    "    keras.layers.InputLayer(input_shape=(28, 28)),\n",
    "    keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
    "    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),\n",
    "    keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
    "    keras.layers.Flatten(),\n",
    "    keras.layers.Dense(10, activation='softmax')\n",
    "  ])\n",
    "\n",
    "  return model\n",
    "\n",
    "\n",
    "def evaluate_tflite_model(filename, x_test, y_test):\n",
    "  '''\n",
    "  Measures the accuracy of a given TF Lite model and test set\n",
    "  \n",
    "  Args:\n",
    "    filename (string) - filename of the model to load\n",
    "    x_test (numpy array) - test images\n",
    "    y_test (numpy array) - test labels\n",
    "\n",
    "  Returns:\n",
    "    float showing the accuracy against the test set\n",
    "  '''\n",
    "\n",
    "  # Initialize the TF Lite Interpreter and allocate tensors\n",
    "  interpreter = tf.lite.Interpreter(model_path=filename)\n",
    "  interpreter.allocate_tensors()\n",
    "\n",
    "  # Get input and output index\n",
    "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
    "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
    "\n",
    "  # Initialize empty predictions list\n",
    "  prediction_digits = []\n",
    "  \n",
    "  # Run predictions on every image in the \"test\" dataset.\n",
    "  for i, test_image in enumerate(x_test):\n",
    "    # Pre-processing: add batch dimension and convert to float32 to match with\n",
    "    # the model's input data format.\n",
    "    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n",
    "    interpreter.set_tensor(input_index, test_image)\n",
    "\n",
    "    # Run inference.\n",
    "    interpreter.invoke()\n",
    "\n",
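    "    # Post-processing: remove batch dimension and find the digit with highest\n",
    "    # probability. get_tensor() returns a copy of the output, so no reference\n",
    "    # to the interpreter's internal buffer is held across invoke() calls.\n",
    "    output = interpreter.get_tensor(output_index)\n",
    "    digit = np.argmax(output[0])\n",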
    "    prediction_digits.append(digit)\n",
    "\n",
    "  # Compare prediction results with ground truth labels to calculate accuracy.\n",
    "  prediction_digits = np.array(prediction_digits)\n",
    "  accuracy = (prediction_digits == y_test).mean()\n",
    "  \n",
    "  return accuracy\n",
    "\n",
    "\n",
    "def get_gzipped_model_size(file):\n",
    "  '''Returns the size of the model file after zip (DEFLATE) compression, in bytes.'''\n",
    "  fd, zipped_file = tempfile.mkstemp('.zip')\n",
    "  os.close(fd)  # close the OS-level handle; ZipFile reopens the file by name\n",
    "  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
    "    f.write(file)\n",
    "\n",
    "  return os.path.getsize(zipped_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AxnjOqLpYawi"
   },
   "source": [
    "## Download and Prepare the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rfC0D71tnVKr"
   },
   "source": [
    "You will be using the [MNIST](https://keras.io/api/datasets/mnist/) dataset, which is hosted in [Keras Datasets](https://keras.io/api/datasets/). Some of the helper functions in this notebook are written to work with this dataset, so if you decide to switch to a different dataset, check whether those helpers need to be modified (e.g. the input shape in `model_builder()`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z5f5Y08r0sob"
   },
   "outputs": [],
   "source": [
    "# Load MNIST dataset\n",
    "mnist = tf.keras.datasets.mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "\n",
    "# Normalize the input image so that each pixel value is between 0 to 1.\n",
    "train_images = train_images / 255.0\n",
    "test_images = test_images / 255.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Czvt9P1EYnQT"
   },
   "source": [
    "## Baseline Model\n",
    "\n",
    "You will first build and train a Keras model. This will be the baseline against which you will compare the mobile-optimized versions later on. It is a shallow CNN with a softmax output that classifies a given MNIST digit. You can review the `model_builder()` function in the utilities at the top of this notebook, but the model summary is also printed below to show the architecture.\n",
    "\n",
    "You will also save the initial weights so you can reinitialize the other models later with the same values. This is not needed in real projects, but for this demo notebook it helps to start from the same initial state so you can compare the effects of the optimizations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3Ild5juYXu4j"
   },
   "outputs": [],
   "source": [
    "# Create the baseline model\n",
    "baseline_model = model_builder()\n",
    "\n",
    "# Save the initial weights for use later\n",
    "baseline_model.save_weights(FILE_WEIGHTS)\n",
    "\n",
    "# Print the model summary\n",
    "baseline_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "74y6LJMVYRCL"
   },
   "source": [
    "You can then compile and train the model. In practice, it's best to shuffle the training set, but for this demo `shuffle` is set to `False` to make the results easier to reproduce. One epoch below will reach around 91% accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xViB61FuY0Pf"
   },
   "outputs": [],
   "source": [
    "# Setup the model for training\n",
    "baseline_model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# Train the model\n",
    "baseline_model.fit(train_images, train_labels, epochs=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "47BgpWwOaR8b"
   },
   "source": [
    "Let's save the accuracy of the model against the test set so you can compare later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JQSVh1_t4Z2h"
   },
   "outputs": [],
   "source": [
    "# Get the baseline accuracy\n",
    "_, ACCURACY['baseline Keras model'] = baseline_model.evaluate(test_images, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aAfbP3uua6bE"
   },
   "source": [
    "Next, you will save the Keras model as a file and record its size as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_A8WPjzqLbH3"
   },
   "outputs": [],
   "source": [
    "# Save the Keras model\n",
    "baseline_model.save(FILE_NON_QUANTIZED_H5, include_optimizer=False)\n",
    "\n",
    "# Record the model size in bytes\n",
    "MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)\n",
    "\n",
    "# Print records so far\n",
    "print_metric(ACCURACY, \"test accuracy\")\n",
    "print_metric(MODEL_SIZE, \"model size in bytes\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ak8rBX-qX_KM"
   },
   "source": [
    "### Convert the model to TF Lite format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PpkXpDy_OCzB"
   },
   "source": [
    "Next, you will convert the model to [TensorFlow Lite (TF Lite)](https://www.tensorflow.org/lite/guide) format. This format is designed to make TensorFlow models more efficient and lightweight when running on mobile, embedded, and IoT devices.\n",
    "\n",
    "You can convert a Keras model with TF Lite's [Converter](https://www.tensorflow.org/lite/convert/index) class (`tf.lite.TFLiteConverter`), which is wrapped in the short helper function below. Notice that there is a `quantize` flag that you can use to quantize the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zQYM0A0SgCNS"
   },
   "outputs": [],
   "source": [
    "def convert_tflite(model, filename, quantize=False):\n",
    "  '''\n",
    "  Converts the model to TF Lite format and writes to a file\n",
    "\n",
    "  Args:\n",
    "    model (Keras model) - model to convert to TF Lite\n",
    "    filename (string) - string to use when saving the file\n",
    "    quantize (bool) - flag to indicate quantization\n",
    "\n",
    "  Returns:\n",
    "    None\n",
    "  '''\n",
    "  \n",
    "  # Initialize the converter\n",
    "  converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
    "\n",
    "  # Set for quantization if flag is set to True\n",
    "  if quantize:\n",
    "    converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "\n",
    "  # Convert the model\n",
    "  tflite_model = converter.convert()\n",
    "\n",
    "  # Save the model.\n",
    "  with open(filename, 'wb') as f:\n",
    "    f.write(tflite_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lQkC9plnP2pU"
   },
   "source": [
    "You will use the helper function to convert the Keras model then get its size and accuracy. Take note that this is *not yet* quantized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5H61feiOZkcI"
   },
   "outputs": [],
   "source": [
    "# Convert baseline model\n",
    "convert_tflite(baseline_model, FILE_NON_QUANTIZED_TFLITE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "REf-EaQlQoYZ"
   },
   "source": [
    "You will notice that there is already a slight decrease in model size when converting to `.tflite` format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cmlNGwbCBo8v"
   },
   "outputs": [],
   "source": [
    "MODEL_SIZE['non quantized tflite'] = os.path.getsize(FILE_NON_QUANTIZED_TFLITE)\n",
    "\n",
    "print_metric(MODEL_SIZE, 'model size in bytes')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Rp-ndoNSRnvX"
   },
   "source": [
    "The accuracy will also be nearly identical when converting between formats. You can run inference with a TF Lite model using its [Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter) class. This is shown in the `evaluate_tflite_model()` helper function provided in the `Utilities` section earlier.\n",
    "\n",
    "*Note: If you see `RuntimeError: There is at least 1 reference to internal data in the interpreter in the form of a numpy array or slice.`, please try re-running the cell.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OQFkh5ukiiZE"
   },
   "outputs": [],
   "source": [
    "ACCURACY['non quantized tflite'] = evaluate_tflite_model(FILE_NON_QUANTIZED_TFLITE, test_images, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CplCOws3jaB0"
   },
   "outputs": [],
   "source": [
    "print_metric(ACCURACY, 'test accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N6ilHiSGYCFL"
   },
   "source": [
    "### Post-Training Quantization\n",
    "\n",
    "Now that you have the baseline metrics, you can observe the effects of quantization. As mentioned in the lectures, this process converts floating point representations into integers to reduce model size and achieve faster computation.\n",
    "\n",
    "As shown in the `convert_tflite()` helper function earlier, you can easily do [post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) with the TF Lite API. You just need to set the converter's `optimizations` attribute to a list containing an [Optimize](https://www.tensorflow.org/api_docs/python/tf/lite/Optimize) enum value, here `tf.lite.Optimize.DEFAULT`. Without a representative dataset, this performs dynamic range quantization: the weights are converted to 8-bit integers at conversion time, while the model's inputs and outputs stay in floating point.\n",
    "\n",
    "You will set the `quantize` flag to do that and get the metrics again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DdWNTJ2J1OpL"
   },
   "outputs": [],
   "source": [
    "# Convert and quantize the baseline model\n",
    "convert_tflite(baseline_model, FILE_PT_QUANTIZED, quantize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cTFHf4Rw1bCJ"
   },
   "outputs": [],
   "source": [
    "# Get the model size\n",
    "MODEL_SIZE['post training quantized tflite'] = os.path.getsize(FILE_PT_QUANTIZED)\n",
    "\n",
    "print_metric(MODEL_SIZE, 'model size in bytes')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SYcBZduWVqOH"
   },
   "source": [
    "You should see around a 4X reduction in model size in the quantized version. This comes from storing the 32-bit floating point weights as 8-bit integers.\n",
    "\n"
   ]
  },
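  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the 4X figure concrete, here is a minimal NumPy sketch of the idea behind 8-bit weight quantization. It assumes a simple symmetric, per-tensor scheme (a single scale maps the tensor onto the integer range [-127, 127]) and uses a hypothetical random tensor with the same shape as the `Conv2D` kernel in `model_builder()`. TF Lite's actual scheme differs in its details (for example, it can use per-channel scales), so treat this as an illustration of where the size reduction and the rounding error come from, not as a description of the converter's internals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical float32 tensor with the same shape as the Conv2D kernel (3x3, 1 -> 12 channels)\n",
    "rng = np.random.default_rng(0)\n",
    "weights = rng.normal(0.0, 0.1, size=(3, 3, 1, 12)).astype(np.float32)\n",
    "\n",
    "# Symmetric per-tensor quantization: one scale maps [-max|w|, max|w|] onto [-127, 127]\n",
    "scale = np.abs(weights).max() / 127.0\n",
    "q_weights = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)\n",
    "\n",
    "# Dequantize to measure the precision lost by rounding\n",
    "deq_weights = q_weights.astype(np.float32) * scale\n",
    "max_error = np.abs(weights - deq_weights).max()\n",
    "\n",
    "print('float32 size in bytes:', weights.nbytes)\n",
    "print('int8 size in bytes:', q_weights.nbytes)\n",
    "print('max abs rounding error:', max_error, '(at most scale / 2 =', scale / 2, ')')"
   ]
  },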
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vhEYoQ83-pT_"
   },
   "outputs": [],
   "source": [
    "ACCURACY['post training quantized tflite'] = evaluate_tflite_model(FILE_PT_QUANTIZED, test_images, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4D0Srsjb_inn"
   },
   "outputs": [],
   "source": [
    "print_metric(ACCURACY, 'test accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rGTzSOuQWG4L"
   },
   "source": [
    "As mentioned in the lecture, you can expect the accuracy to change slightly after quantizing the model. Most of the time it will decrease, but in some cases it can even increase. This can be attributed to the loss of precision when the weights are reduced from 32-bit floats to 8-bit integers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vFf1DDVnYIes"
   },
   "source": [
    "## Quantization Aware Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "37oAb7PuXK36"
   },
   "source": [
    "When post-training quantization results in a loss of accuracy that is unacceptable for your application, you can consider doing [quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training) before quantizing the model. This simulates the loss of precision by inserting fake quantization nodes into the model during training. That way, the model learns to adapt to the loss of precision and makes more accurate predictions after quantization.\n",
    "\n",
    "The [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization) provides a [quantize_model()](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/quantize_model) function to do this quickly, as you will see below. But first, let's install the toolkit into the notebook environment."
   ]
  },
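  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fake quantization nodes mentioned above can be sketched in a few lines of NumPy. The hypothetical `fake_quant()` helper below rounds values onto the 256 levels that an 8-bit integer can represent within an assumed range `[x_min, x_max]` and then maps them back to floats, so the output stays in floating point but carries only 8 bits of information. This covers only the forward pass and simplifies the real op (TensorFlow's built-in `tf.quantization.fake_quant_with_min_max_vars`, for instance, also adjusts the range so that zero is exactly representable); during training, gradients are typically passed straight through the rounding step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fake_quant(x, x_min, x_max, num_bits=8):\n",
    "    '''Hypothetical helper: quantize x to num_bits and immediately dequantize it.'''\n",
    "    levels = 2 ** num_bits - 1                 # 255 steps between 256 representable values\n",
    "    scale = (x_max - x_min) / levels           # float width of one integer step\n",
    "    x_clipped = np.clip(x, x_min, x_max)       # values outside the range saturate\n",
    "    q = np.round((x_clipped - x_min) / scale)  # integer level in [0, levels]\n",
    "    return q * scale + x_min                   # back to float, now on a coarse grid\n",
    "\n",
    "x = np.linspace(-1.0, 1.0, 1001).astype(np.float32)\n",
    "y = fake_quant(x, -1.0, 1.0)\n",
    "\n",
    "print('distinct values before fake quantization:', np.unique(x).size)\n",
    "print('distinct values after fake quantization:', np.unique(y).size)\n",
    "print('max abs error:', np.abs(y - x).max())"
   ]
  },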
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6WSt6OQGoNAt"
   },
   "outputs": [],
   "source": [
    "# Install the toolkit\n",
    "!pip install tensorflow_model_optimization==0.7.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oYHmeMihYjnB"
   },
   "source": [
    "You will build the baseline model again, but this time you will pass it into the `quantize_model()` function to prepare it for quantization aware training.\n",
    "\n",
    "Note that `quantize_model()` returns a new model that must be compiled before training, even if the model you pass in was already compiled and trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3dGSpz0on2C4"
   },
   "outputs": [],
   "source": [
    "import tensorflow_model_optimization as tfmot\n",
    "\n",
    "# method to quantize a Keras model\n",
    "quantize_model = tfmot.quantization.keras.quantize_model\n",
    "\n",
    "# Define the model architecture.\n",
    "model_to_quantize = model_builder()\n",
    "\n",
    "# Reinitialize weights with saved file\n",
    "model_to_quantize.load_weights(FILE_WEIGHTS)\n",
    "\n",
    "# Quantize the model\n",
    "q_aware_model = quantize_model(model_to_quantize)\n",
    "\n",
    "# `quantize_model` requires a recompile.\n",
    "q_aware_model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "q_aware_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lmcaLaotZ7G7"
   },
   "source": [
    "You may have noticed a slight difference in the model summary above compared to the baseline model summary in the earlier sections. The total param count increased, as expected, because of the quantization nodes added by the `quantize_model()` function.\n",
    "\n",
    "With that, you can now train the model. The accuracy may be a bit lower than the baseline because the model is simulating the loss of precision, so you would need to train a bit longer to reach the same training accuracy as the earlier run. For this exercise though, we will keep to 1 epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yl4jbjllomDw"
   },
   "outputs": [],
   "source": [
    "# Train the model\n",
    "q_aware_model.fit(train_images, train_labels, epochs=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b_WAM2C4bWeC"
   },
   "source": [
    "You can then get the accuracy of the Keras model before and after quantizing it. The two are expected to be nearly identical because the model was trained to counter the effects of quantization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J7rOuwM_ozI_"
   },
   "outputs": [],
   "source": [
    "# Reinitialize the dictionary\n",
    "ACCURACY = {}\n",
    "\n",
    "# Get the accuracy of the quantization aware trained model (not yet quantized)\n",
    "_, ACCURACY['quantization aware non-quantized'] = q_aware_model.evaluate(test_images, test_labels, verbose=0)\n",
    "print_metric(ACCURACY, 'test accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6liE_Cp3rzAy"
   },
   "outputs": [],
   "source": [
    "# Convert and quantize the model.\n",
    "convert_tflite(q_aware_model, FILE_QAT_QUANTIZED, quantize=True)\n",
    "\n",
    "# Get the accuracy of the quantized model\n",
    "ACCURACY['quantization aware quantized'] = evaluate_tflite_model(FILE_QAT_QUANTIZED, test_images, test_labels)\n",
    "print_metric(ACCURACY, 'test accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SwvaMflTYNgo"
   },
   "source": [
    "## Pruning\n",
    "\n",
    "Let's now move on to another technique for reducing model size: [Pruning](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras). This process involves zeroing out insignificant (i.e. low magnitude) weights. The intuition is that these weights contribute little to the predictions, so you can remove them with little effect on the result. Making the weights sparse helps in compressing the model more efficiently, as you will see in this section."
   ]
  },
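  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core idea of magnitude-based pruning can be sketched with NumPy. The hypothetical `magnitude_prune()` helper below zeroes out the requested fraction of entries with the smallest absolute values in a single shot. This is a simplification: `prune_low_magnitude()` applies its mask gradually during training according to a pruning schedule, so the remaining weights can adapt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def magnitude_prune(weights, sparsity):\n",
    "    '''Hypothetical helper: zero out the `sparsity` fraction of smallest-|w| entries.'''\n",
    "    k = int(round(sparsity * weights.size))  # number of weights to remove\n",
    "    if k == 0:\n",
    "        return weights.copy()\n",
    "    # The k-th smallest magnitude becomes the pruning threshold\n",
    "    threshold = np.partition(np.abs(weights).ravel(), k - 1)[k - 1]\n",
    "    mask = np.abs(weights) > threshold  # keep only weights above the threshold\n",
    "    return weights * mask\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "w = rng.normal(size=(100, 10)).astype(np.float32)\n",
    "pruned = magnitude_prune(w, 0.8)\n",
    "\n",
    "print('fraction of zeros after pruning:', np.mean(pruned == 0))\n",
    "print('smallest surviving magnitude:', np.abs(pruned[pruned != 0]).min())"
   ]
  },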
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LdlFujrJbzV7"
   },
   "source": [
    "The TensorFlow Model Optimization Toolkit again has a convenience function for this. The [prune_low_magnitude()](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) function wraps the layers of a Keras model so they can be pruned during training. You will pass in the baseline model that you already trained earlier. You will notice that the model summary shows more params because of the wrapper layers added by the pruning function.\n",
    "\n",
    "You can set how the pruning is done during training. Below, you will use [PolynomialDecay](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PolynomialDecay) to indicate how the sparsity ramps up with each step. Another option available in the library is [ConstantSparsity](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/ConstantSparsity)."
   ]
  },
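  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how the sparsity ramps up, here is a small sketch of the polynomial schedule, based on the formula in the `PolynomialDecay` documentation and assuming its default `power` of 3: the target sparsity moves from `initial_sparsity` to `final_sparsity` as `(1 - progress) ** 3` shrinks, so most of the increase happens early in training. The hypothetical `polynomial_decay_sparsity()` helper ignores the schedule's `frequency` argument (how often the mask is actually updated) and plugs in the same values used in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,\n",
    "                              begin_step, end_step, power=3):\n",
    "    '''Hypothetical helper: target sparsity at a given training step.'''\n",
    "    # Fraction of the schedule completed, clipped to [0, 1]\n",
    "    progress = np.clip((step - begin_step) / (end_step - begin_step), 0.0, 1.0)\n",
    "    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power\n",
    "\n",
    "# Same arithmetic as the next cell: 60,000 MNIST training images with 10% held out\n",
    "# for validation, batch size 128, and 2 epochs\n",
    "num_images = 60000 * (1 - 0.1)\n",
    "end_step = int(np.ceil(num_images / 128)) * 2\n",
    "\n",
    "for step in [0, end_step // 4, end_step // 2, end_step]:\n",
    "    sparsity = polynomial_decay_sparsity(step, 0.50, 0.80, 0, end_step)\n",
    "    print(f'step {step:4d}: target sparsity {sparsity:.3f}')"
   ]
  },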
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TpqizJsKYPBA"
   },
   "outputs": [],
   "source": [
    "# Get the pruning method\n",
    "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude\n",
    "\n",
    "# Compute end step to finish pruning after 2 epochs.\n",
    "batch_size = 128\n",
    "epochs = 2\n",
    "validation_split = 0.1 # 10% of training set will be used for validation set. \n",
    "\n",
    "num_images = train_images.shape[0] * (1 - validation_split)\n",
    "end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs\n",
    "\n",
    "# Define pruning schedule.\n",
    "pruning_params = {\n",
    "      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,\n",
    "                                                               final_sparsity=0.80,\n",
    "                                                               begin_step=0,\n",
    "                                                               end_step=end_step)\n",
    "}\n",
    "\n",
    "# Pass in the trained baseline model\n",
    "model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)\n",
    "\n",
    "# `prune_low_magnitude` requires a recompile.\n",
    "model_for_pruning.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model_for_pruning.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qgmHaZI6fip_"
   },
   "source": [
    "You can also peek at the weights of one of the layers in your model. After pruning, you will notice that many of these will be zeroed out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y5ekdEBigB5l"
   },
   "outputs": [],
   "source": [
    "# Preview model weights\n",
    "model_for_pruning.weights[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0XFwMRqpgbr0"
   },
   "source": [
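    "With that, you can now start re-training the model. Take note that the [UpdatePruningStep()](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/UpdatePruningStep) callback is required: it updates the pruning wrappers with the current training step, which the pruning schedule uses to set the target sparsity."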
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DUCz6PL371Bx"
   },
   "outputs": [],
   "source": [
    "# Callback to update pruning wrappers at each step\n",
    "callbacks = [\n",
    "  tfmot.sparsity.keras.UpdatePruningStep(),\n",
    "]\n",
    "\n",
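    "# Train and prune the model. Pass the same batch_size used to compute end_step\n",
    "# so that the pruning schedule spans the full 2 epochs.\n",
    "model_for_pruning.fit(train_images, train_labels,\n",
    "                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,\n",
    "                  callbacks=callbacks)"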
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rEExgy4hhXP-"
   },
   "source": [
    "Now look at the weights in the same layer after pruning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TOK4TidJhXpT"
   },
   "outputs": [],
   "source": [
    "# Preview model weights\n",
    "model_for_pruning.weights[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o5ckfDHLhhub"
   },
   "source": [
    "After pruning, you can remove the wrapper layers to get the same layers and params as the baseline model. You can do that with the [strip_pruning()](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/strip_pruning) function as shown below. You will do this so you can save the model and also export it to TF Lite format, just like in the previous sections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PbfLhZv68vwc"
   },
   "outputs": [],
   "source": [
    "# Remove pruning wrappers\n",
    "model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)\n",
    "model_for_export.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KtbPlo-kj9Ku"
   },
   "source": [
    "You will see the same model weights but the index is different because the wrappers were removed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SG6-aF9yiraG"
   },
   "outputs": [],
   "source": [
    "# Preview model weights (index 1 earlier is now 0 because pruning wrappers were removed)\n",
    "model_for_export.weights[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZR94MYxLkHfn"
   },
   "source": [
    "You will notice below that the pruned model has the same file size as `baseline_model` when saved as H5. This is to be expected, because H5 stores the zeroed weights just like any other value. The improvement will be noticeable when you compress the model, as shown in the cell after this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CjjDMqJCTjqz"
   },
   "outputs": [],
   "source": [
    "# Save Keras model\n",
    "model_for_export.save(FILE_PRUNED_MODEL_H5, include_optimizer=False)\n",
    "\n",
    "# Get uncompressed model size of baseline and pruned models\n",
    "MODEL_SIZE = {}\n",
    "MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)\n",
    "MODEL_SIZE['pruned non quantized h5'] = os.path.getsize(FILE_PRUNED_MODEL_H5)\n",
    "\n",
    "print_metric(MODEL_SIZE, 'model size in bytes')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tCEfa-LRleT_"
   },
   "source": [
    "You will use the `get_gzipped_model_size()` helper function in the `Utilities` section to compress the models and get their resulting file sizes. You will notice that the pruned model is about 3 times smaller. This is because of the sparse weights generated by the pruning process: the zeros can be compressed much more efficiently than the low magnitude weights they replaced."
   ]
  },
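  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why do zeros help? DEFLATE, the algorithm selected by `zipfile.ZIP_DEFLATED` in `get_gzipped_model_size()`, exploits repeated byte patterns, and runs of zero bytes are highly repetitive while random float weights are not. The sketch below checks this on synthetic data: it compresses a hypothetical dense float32 array and an 80%-sparse copy of it in memory with the same zip settings. The exact sizes depend on the data, so no particular ratio is implied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import zipfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def deflated_size(array):\n",
    "    '''Hypothetical helper: zip-compressed (DEFLATE) size of an array's raw bytes.'''\n",
    "    buffer = io.BytesIO()\n",
    "    with zipfile.ZipFile(buffer, 'w', compression=zipfile.ZIP_DEFLATED) as zf:\n",
    "        zf.writestr('weights.bin', array.tobytes())\n",
    "    return len(buffer.getvalue())\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "dense = rng.normal(size=100_000).astype(np.float32)\n",
    "\n",
    "# Zero out the 80% of entries with the smallest magnitude\n",
    "threshold = np.quantile(np.abs(dense), 0.8)\n",
    "sparse = np.where(np.abs(dense) > threshold, dense, 0.0).astype(np.float32)\n",
    "\n",
    "print('uncompressed size in bytes (both):', dense.nbytes)\n",
    "print('dense, compressed size in bytes:', deflated_size(dense))\n",
    "print('sparse, compressed size in bytes:', deflated_size(sparse))"
   ]
  },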
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VWQ_AgiX_yiP"
   },
   "outputs": [],
   "source": [
    "# Get compressed size of baseline and pruned models\n",
    "MODEL_SIZE = {}\n",
    "MODEL_SIZE['baseline h5'] = get_gzipped_model_size(FILE_NON_QUANTIZED_H5)\n",
    "MODEL_SIZE['pruned non quantized h5'] = get_gzipped_model_size(FILE_PRUNED_MODEL_H5)\n",
    "\n",
    "print_metric(MODEL_SIZE, \"gzipped model size in bytes\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uByyx0L3mlYc"
   },
   "source": [
    "You can make the model even more lightweight by quantizing the pruned model. This achieves around 10X reduction in compressed model size as compared to the baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qIY6n9XWCvt5"
   },
   "outputs": [],
   "source": [
    "# Convert and quantize the pruned model (convert_tflite() writes the file and returns None).\n",
    "convert_tflite(model_for_export, FILE_PRUNED_QUANTIZED_TFLITE, quantize=True)\n",
    "\n",
    "# Compress and get the model size\n",
    "MODEL_SIZE['pruned quantized tflite'] = get_gzipped_model_size(FILE_PRUNED_QUANTIZED_TFLITE)\n",
    "print_metric(MODEL_SIZE, \"gzipped model size in bytes\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v4ytiH3ynIid"
   },
   "source": [
    "As expected, the TF Lite model's accuracy will also be close to the Keras model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PZBAdJmuWN0A"
   },
   "outputs": [],
   "source": [
    "# Get accuracy of pruned Keras and TF Lite models\n",
    "ACCURACY = {}\n",
    "\n",
    "_, ACCURACY['pruned model h5'] = model_for_pruning.evaluate(test_images, test_labels)\n",
    "ACCURACY['pruned and quantized tflite'] = evaluate_tflite_model(FILE_PRUNED_QUANTIZED_TFLITE, test_images, test_labels)\n",
    "\n",
    "print_metric(ACCURACY, 'accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CpM7t_nGokcz"
   },
   "source": [
    "## Wrap Up\n",
    "\n",
    "In this notebook, you practiced several techniques for optimizing your models for mobile and embedded applications. You used quantization to convert floating point representations into integers, then used pruning to make the weights sparse for efficient model compression. These make your models lightweight for efficient transport and storage with little loss in model accuracy. Try this in your own models and see what performance you get. For more information, here are a few other resources:\n",
    "\n",
    "* [Post Training Quantization Guide](https://www.tensorflow.org/lite/performance/post_training_quantization)\n",
    "* [Quantization Aware Training Comprehensive Guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide)\n",
    "* [Pruning Comprehensive Guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide)\n",
    "\n",
    "**Congratulations and enjoy the rest of the course!**"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "C3_W2_Lab_3_Quantization_and_Pruning.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
